{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "grBmytrShbUE"
      },
      "source": [
        "# High-performance Simulation with Kubernetes\n",
        "\n",
        "This tutorial describes how to set up high-performance simulation using a\n",
        "TFF runtime running on Kubernetes. The model is the same as in the previous\n",
        "tutorial, **High-performance simulations with TFF**. The only difference is that\n",
        "here we use a worker pool instead of a local executor.\n",
        "\n",
        "This tutorial uses Google Cloud's [GKE](https://cloud.google.com/kubernetes-engine/) to create the Kubernetes cluster,\n",
        "but every step after cluster creation works with any Kubernetes\n",
        "installation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SyXVaj0dknQw"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/high_performance_simulation_with_kubernetes\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/high_performance_simulation_with_kubernetes.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/master/docs/tutorials/high_performance_simulation_with_kubernetes.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yiq_MY4LopET"
      },
      "source": [
        "## Launch the TFF Workers on GKE\n",
        "\n",
        "> **Note:** This tutorial assumes the user has an existing GCP project.\n",
        "\n",
        "### Create a Kubernetes Cluster\n",
        "\n",
        "The following step only needs to be done once. The cluster can be re-used for future workloads.\n",
        "\n",
        "Follow the GKE instructions to [create a container cluster](https://cloud.google.com/kubernetes-engine/docs/tutorials/hello-app#step_4_create_a_container_cluster). The rest of this tutorial assumes that the cluster is named `tff-cluster`, but the actual name isn't important.\n",
        "Stop following the instructions when you get to \"*Step 5: Deploy your application*\".\n",
        "\n",
        "### Deploy the TFF Worker Application\n",
        "\n",
        "The commands to interact with GCP can be run [locally](https://cloud.google.com/kubernetes-engine/docs/tutorials/hello-app#option_b_use_command-line_tools_locally) or in the [Google Cloud Shell](https://cloud.google.com/shell/). We recommend the Google Cloud Shell since it doesn't require additional setup.\n",
        "\n",
        "1. Run the following command to launch the Kubernetes application.\n",
        "\n",
        "```\n",
        "$ kubectl create deployment tff-workers --image=gcr.io/tensorflow-federated/remote-executor-service:latest\n",
        "```\n",
        "\n",
        "2. Add a load balancer for the application.\n",
        "\n",
        "```\n",
        "$ kubectl expose deployment tff-workers --type=LoadBalancer --port 80 --target-port 8000\n",
        "```\n",
        "\n",
        "> **Note:** This exposes your deployment to the internet and is for demo\n",
        "purposes only. For production use, a firewall and authentication are strongly\n",
        "recommended."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WK4ohHUZvVVc"
      },
      "source": [
        "Look up the external IP address of the load balancer in the Google Cloud Console, or run `kubectl get service tff-workers` and read the `EXTERNAL-IP` column. You'll need it later to connect the training loop to the worker app."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Lq8r5uaT2rB"
      },
      "source": [
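        "### (Alternatively) Launch the Docker Container Locally\n",
        "\n",
        "```\n",
        "$ docker run --rm -p 8000:8000 gcr.io/tensorflow-federated/remote-executor-service:latest\n",
        "```\n",
        "\n",
        "The `-p 8000:8000` flag publishes the worker on port 8000 of the host, so a notebook running on that same machine would connect to `localhost` on port 8000 rather than to the load balancer on port 80."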
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_zFenI3IPpgI"
      },
      "source": [
        "## Set Up TFF Environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ke7EyuvG0Zyn"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow-federated-nightly\n",
        "!pip install --quiet --upgrade nest-asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dFkcJZAojZDm"
      },
      "source": [
        "## Define the Model to Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J0Qk0sCDZUQR"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "import time\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "source, _ = tff.simulation.datasets.emnist.load_data()\n",
        "\n",
        "\n",
        "def map_fn(example):\n",
        "  return collections.OrderedDict(\n",
        "      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])\n",
        "\n",
        "\n",
        "def client_data(n):\n",
        "  ds = source.create_tf_dataset_for_client(source.client_ids[n])\n",
        "  return ds.repeat(10).batch(20).map(map_fn)\n",
        "\n",
        "\n",
        "train_data = [client_data(n) for n in range(10)]\n",
        "input_spec = train_data[0].element_spec\n",
        "\n",
        "\n",
        "def model_fn():\n",
        "  model = tf.keras.models.Sequential([\n",
        "      tf.keras.layers.InputLayer(input_shape=(784,)),\n",
        "      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),\n",
        "      tf.keras.layers.Softmax(),\n",
        "  ])\n",
        "  return tff.learning.from_keras_model(\n",
        "      model,\n",
        "      input_spec=input_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n",
        "\n",
        "\n",
        "trainer = tff.learning.build_federated_averaging_process(\n",
        "    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))\n",
        "\n",
        "\n",
        "def evaluate(num_rounds=10):\n",
        "  state = trainer.initialize()\n",
        "  for round_num in range(num_rounds):\n",
        "    # Time one full round of federated training.\n",
        "    t1 = time.time()\n",
        "    state, metrics = trainer.next(state, train_data)\n",
        "    t2 = time.time()\n",
        "    print('Round {}: loss {}, round time {}'.format(round_num, metrics.loss, t2 - t1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x5OhgAp7jrNI"
      },
      "source": [
        "## Set Up the Remote Executors\n",
        "\n",
        "By default, TFF executes all computations locally. In this step we tell TFF to connect to the Kubernetes service we set up above. Replace the placeholder `ip_address` with the external IP address of your load balancer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sXSLXwcdciYm"
      },
      "outputs": [],
      "source": [
        "import grpc\n",
        "\n",
        "ip_address = '0.0.0.0'  #@param {type:\"string\"}\n",
        "port = 80  #@param {type:\"integer\"}\n",
        "\n",
        "channels = [grpc.insecure_channel(f'{ip_address}:{port}') for _ in range(10)]\n",
        "\n",
        "tff.backends.native.set_remote_execution_context(channels)"
      ]
    },
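    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If training hangs or fails to connect, a quick first check is whether the worker address accepts TCP connections at all. The following is a minimal sketch using only the Python standard library; `worker_reachable` is a hypothetical helper, not part of TFF or gRPC, and a successful TCP connection only shows that something is listening on the port, not that the executor service itself is healthy.\n",
        "\n",
        "```python\n",
        "import socket\n",
        "\n",
        "\n",
        "def worker_reachable(host, port, timeout=3.0):\n",
        "  # Hypothetical helper: True if a TCP connection to host:port opens\n",
        "  # within `timeout` seconds, False on refusal, timeout, or DNS failure.\n",
        "  try:\n",
        "    with socket.create_connection((host, port), timeout=timeout):\n",
        "      return True\n",
        "  except OSError:\n",
        "    return False\n",
        "```\n",
        "\n",
        "For example, `worker_reachable(ip_address, port)` is expected to return `True` once the load balancer has an external IP and the worker pod is running."
      ]
    },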
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bEgpmgSRktJY"
      },
      "source": [
        "## Run Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mw92IA6_Zrud"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Round 0: loss 4.370407581329346, round time 4.201097726821899\n",
            "Round 1: loss 4.1407670974731445, round time 3.3283166885375977\n",
            "Round 2: loss 3.865147590637207, round time 3.098310947418213\n",
            "Round 3: loss 3.534019708633423, round time 3.1565616130828857\n",
            "Round 4: loss 3.272688388824463, round time 3.175067663192749\n",
            "Round 5: loss 2.935391664505005, round time 3.008434534072876\n",
            "Round 6: loss 2.7399251461029053, round time 3.31435227394104\n",
            "Round 7: loss 2.5054931640625, round time 3.4411356449127197\n",
            "Round 8: loss 2.290508985519409, round time 3.158798933029175\n",
            "Round 9: loss 2.1194536685943604, round time 3.1348156929016113\n"
          ]
        }
      ],
      "source": [
        "evaluate()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "high_performance_simulation_with_kubernetes.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
